from functools import reduce
from heapq import merge
from telnetlib import IP
import networkx as nx
import argparse
import os

from DFA import DFA, State

class DFAGenerator:
    def __init__(self, env, k=2):
        """Load the topology stored under *env* into a multigraph.

        Args:
            env: Path of the environment directory. ``<env>/topo.txt`` is
                read here; the generators write their output below it.
            k: Size parameter of the topology (iterated as the pod count
                by the fabric / fat-tree generators).

        Each non-blank line of ``topo.txt`` is split on whitespace and
        interpreted as ``<node_a> <port_a> <node_b> <port_b>``. One edge is
        added per line, carrying a ``portmap`` attribute that maps each
        endpoint to its local port.

        Raises:
            OSError: if ``<env>/topo.txt`` cannot be opened.
            IndexError: if a non-blank line has fewer than four fields.
        """
        self.k = k
        self.env = env
        self.topo = nx.MultiGraph()
        self.merged = None
        self.node_to_dfas = {}
        # A context manager closes the file deterministically; the previous
        # bare ``open()`` in the loop header leaked the handle.
        with open(env + '/topo.txt') as topo_file:
            for line in topo_file:
                arr = line.split()
                if not arr:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                self.topo.add_edge(arr[0], arr[2], portmap={arr[0]: arr[1], arr[2]: arr[3]})
        # NOTE: per-requirement ECMP DFA generation driven by
        # ``<env>/gen/requirements.txt`` used to be triggered from here and is
        # disabled; call ``ecmp`` explicitly if it is needed.

    def gen_dfa_dst(self):
        """Write one next-hop graph per destination rack switch.

        For every destination ``rsw-<pod>-<idx>`` (``self.k`` pods, 48 racks
        per pod) a directed graph is built over the topology nodes and saved
        as ``<destination>.txt`` through ``save_dfa``. Edges are chosen from
        the node name alone (assumes names look like ``rsw-<pod>-<idx>``,
        ``fsw-<pod>-<idx>`` and ``ssw-<a>-<b>`` -- TODO confirm):

        * any other ``rsw`` points at ``fsw-<its pod>-0..3``;
        * an ``fsw`` in the destination pod points straight at the
          destination, any other ``fsw`` points at ``ssw-<its idx>-0..47``;
        * an ``ssw`` points at ``fsw-<destination pod>-<its first index>``.
        """
        for pod in range(self.k):
            for rack in range(48):
                target = 'rsw-%d-%d' % (pod, rack)
                graph = nx.DiGraph()
                for name in self.topo:
                    parts = name.split('-')
                    role = parts[0]
                    if role == 'rsw':
                        # The destination itself has no outgoing edge.
                        if name != target:
                            for up in range(4):
                                graph.add_edge(name, 'fsw-%s-%d' % (parts[1], up))
                    elif role == 'fsw':
                        if int(parts[1]) != pod:
                            for up in range(48):
                                graph.add_edge(name, 'ssw-%s-%d' % (parts[2], up))
                        else:
                            graph.add_edge(name, target)
                    elif role == 'ssw':
                        graph.add_edge(name, 'fsw-%d-%s' % (pod, parts[1]))
                self.save_dfa(graph, target + '.txt')

    def merge_dfa(self, dfa):
        if self.merged == None:
            self.merged = dfa
            return
        for node in dfa.nodes:
            pass

    def get_or_create(self, key):
        if key not in self.node_to_dfas:
            self.node_to_dfas[key] = []
        return self.node_to_dfas[key]

    def gen_dfa_opt_ft(self):
        """Build the combined fat-tree DFA and write it to ``opt_dfa.txt``.

        Nodes are strings of the form ``<switch>:<tag>``. With ``half`` being
        ``int(self.k / 2)``, the edges added are:

        * ``rsw-p-r:0`` -> ``fsw-p-f:100`` and -> ``fsw-p-f:<other rack>``;
        * ``fsw-p-f:100`` -> ``ssw-f-s:<other pod>``;
        * ``fsw-p-f:<rack>`` -> ``rsw-p-<rack>:1``;
        * ``ssw-i-s:<pod>`` -> ``fsw-<pod>-i:<rack>``.

        The node and edge counts are printed, then the edge list is written
        to ``<env>/gen/dfas/opt_dfa.txt``.

        NOTE(review): ``gen_dfa_opt`` tags the final rack state ``:2`` where
        this method uses ``:1`` -- confirm which is intended.
        """
        graph = nx.DiGraph()
        half = int(self.k / 2)
        for pod in range(self.k):
            for rack in range(half):
                rack_start = 'rsw-%d-%d:0' % (pod, rack)
                for agg in range(half):
                    graph.add_edge(rack_start, 'fsw-%d-%d:100' % (pod, agg))
                    for peer in range(half):
                        if peer != rack:
                            graph.add_edge(rack_start, 'fsw-%d-%d:%d' % (pod, agg, peer))
            for agg in range(half):
                agg_up = 'fsw-%d-%d:100' % (pod, agg)
                for core in range(half):
                    for other_pod in range(self.k):
                        if other_pod != pod:
                            graph.add_edge(agg_up, 'ssw-%d-%d:%d' % (agg, core, other_pod))

                for rack in range(half):
                    graph.add_edge('fsw-%d-%d:%d' % (pod, agg, rack), 'rsw-%d-%d:1' % (pod, rack))
        for plane in range(half):
            for core in range(half):
                for dst_pod in range(self.k):
                    core_state = 'ssw-%d-%d:%d' % (plane, core, dst_pod)
                    for rack in range(half):
                        graph.add_edge(core_state, 'fsw-%d-%d:%d' % (dst_pod, plane, rack))
        print(len(graph.nodes))
        print(len(graph.edges))

        # save_dfa writes "<src> <dst>" lines below <env>/gen/dfas/, which is
        # exactly the file this method produces.
        self.save_dfa(graph, 'opt_dfa.txt')

    def gen_dfa_opt(self):
        """Build the combined fabric DFA and write it to ``opt_dfa.txt``.

        Uses fixed sizes: 48 racks per pod, 4 ``fsw`` per pod, and 4 x 48
        ``ssw``. Nodes are strings of the form ``<switch>:<tag>``. The edges
        added are:

        * ``rsw-p-r:0`` -> ``fsw-p-f:100`` and -> ``fsw-p-f:<other rack>``;
        * ``fsw-p-f:100`` -> ``ssw-f-s:<other pod>``;
        * ``fsw-p-f:<rack>`` -> ``rsw-p-<rack>:2``;
        * ``ssw-i-s:<pod>`` -> ``fsw-<pod>-i:<rack>``.

        The node and edge counts are printed, then the edge list is written
        to ``<env>/gen/dfas/opt_dfa.txt``.

        NOTE(review): ``gen_dfa_opt_ft`` tags the final rack state ``:1``
        where this method uses ``:2`` -- confirm which is intended.
        """
        racks_per_pod = 48
        fsw_per_pod = 4
        ssw_per_plane = 48

        graph = nx.DiGraph()
        for pod in range(self.k):
            for rack in range(racks_per_pod):
                rack_start = 'rsw-%d-%d:0' % (pod, rack)
                for agg in range(fsw_per_pod):
                    graph.add_edge(rack_start, 'fsw-%d-%d:100' % (pod, agg))
                    for peer in range(racks_per_pod):
                        if peer != rack:
                            # TODO: check it!
                            graph.add_edge(rack_start, 'fsw-%d-%d:%d' % (pod, agg, peer))
            for agg in range(fsw_per_pod):
                agg_up = 'fsw-%d-%d:100' % (pod, agg)
                for core in range(ssw_per_plane):
                    for other_pod in range(self.k):
                        if other_pod != pod:
                            graph.add_edge(agg_up, 'ssw-%d-%d:%d' % (agg, core, other_pod))

                for rack in range(racks_per_pod):
                    graph.add_edge('fsw-%d-%d:%d' % (pod, agg, rack), 'rsw-%d-%d:2' % (pod, rack))
        for plane in range(fsw_per_pod):
            for core in range(ssw_per_plane):
                for dst_pod in range(self.k):
                    core_state = 'ssw-%d-%d:%d' % (plane, core, dst_pod)
                    for rack in range(racks_per_pod):
                        graph.add_edge(core_state, 'fsw-%d-%d:%d' % (dst_pod, plane, rack))
        print(len(graph.nodes))
        print(len(graph.edges))

        # save_dfa writes "<src> <dst>" lines below <env>/gen/dfas/, which is
        # exactly the file this method produces.
        self.save_dfa(graph, 'opt_dfa.txt')
        

    def save_dfa(self, dfa, name):
        dfa_file = self.env + '/gen/dfas/' + name
        os.makedirs(os.path.dirname(dfa_file), exist_ok=True)
        with open(dfa_file, 'w') as f:
            for e in dfa.edges:
                f.write('%s %s\n' % (e[0], e[1]))

    def ecmp(self, src, dst, dfa_id):
        """Write the union of all shortest *src* -> *dst* paths as a DFA file.

        Args:
            src: Start node in the topology.
            dst: End node in the topology.
            dfa_id: Identifier used as the output file name (``<dfa_id>.txt``).

        Every consecutive node pair of every shortest path becomes a directed
        edge. The path count and edge count are printed.

        NOTE(review): the file goes to ``<env>/dfas/`` whereas ``save_dfa``
        writes to ``<env>/gen/dfas/`` -- confirm the difference is intended.
        """
        shortest = list(nx.all_shortest_paths(self.topo, src, dst))
        print('ecmp paths: %d' % len(shortest))
        graph = nx.DiGraph()
        for hops in shortest:
            for here, there in zip(hops, hops[1:]):
                if not graph.has_edge(here, there):
                    graph.add_edge(here, there)

        print('DFA edges: %d' % len(graph.edges))
        out_path = self.env + '/dfas/' + str(dfa_id) + '.txt'
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        with open(out_path, 'w') as out:
            out.writelines('%s %s\n' % (a, b) for a, b in graph.edges)

    def gen_dfa_reach(self, cutoff=None, opt=True):
        """Generate reachability DFAs for every ordered pair of topology nodes.

        For each ``(src, dst)`` pair, every simple path reported by
        ``nx.all_simple_paths`` becomes a linear DFA. The per-path DFAs are
        folded together with ``merge`` and reduced by calling ``minimize``
        until it reports no further change.

        Args:
            cutoff: Forwarded to ``nx.all_simple_paths`` as the path length
                bound; ``None`` means unbounded.
            opt: When true, all per-pair DFAs are merged into a single DFA
                written to ``<env>/gen/dfas/opt_dfa.txt``. Otherwise each
                per-pair DFA is written to ``<env>/gen/noptdfas/<i>.txt``.

        Pairs with no path (disconnected nodes, or no path within *cutoff*)
        are skipped. Nothing is written when no pair produced a DFA.
        """
        raw_states = 0
        raw_trans = 0
        merged_dfas = []
        for src in self.topo.nodes:
            for dst in self.topo.nodes:
                if src == dst:
                    continue
                print(src, dst)
                # The same node sequence can be reported more than once
                # (presumably once per parallel link of the multigraph), so
                # de-duplicate. Sorting makes the result reproducible: the
                # iteration order of a set of strings varies between runs.
                paths = sorted({
                    tuple(path)
                    for path in nx.all_simple_paths(self.topo, src, dst, cutoff)
                })
                if not paths:
                    # reduce() without an initial value raises TypeError on
                    # an empty sequence.
                    continue

                dfas = [self._path_to_dfa(path) for path in paths]
                for dfa in dfas:
                    print(dfa.graph)
                print(sum(len(x.graph.nodes) for x in dfas))
                merged_dfa = reduce(self.merge, dfas)
                print(merged_dfa.graph)
                while self.minimize(merged_dfa):
                    pass
                merged_dfas.append(merged_dfa)
                raw_states += len(merged_dfa.graph.nodes)
                raw_trans += len(merged_dfa.graph.edges)

        if not merged_dfas:
            # No reachable pair at all: there is nothing to merge or write.
            return

        if opt:
            opt_dfa = reduce(self.merge, merged_dfas)
            while self.minimize(opt_dfa):
                pass
            print(raw_states, raw_trans)
            print(len(opt_dfa.graph.nodes), len(opt_dfa.graph.edges))
            self._write_dfa_with_ports(opt_dfa, self.env + '/gen/dfas/opt_dfa.txt')
        else:
            for i, dfa in enumerate(merged_dfas):
                self._write_dfa_with_ports(dfa, self.env + ('/gen/noptdfas/%d.txt' % i))

    def _path_to_dfa(self, path):
        """Return a linear DFA: start state, then one state per node of *path*."""
        dfa = DFA()
        last = dfa.start
        for device in path:
            state = dfa.add_state(device)
            dfa.add_transition(last, state, state.device)
            last = state
        return dfa

    def _write_dfa_with_ports(self, dfa, dfa_file):
        """Write ``"<from> <to> <ports>"`` lines for *dfa*, skipping start edges."""
        os.makedirs(os.path.dirname(dfa_file), exist_ok=True)
        with open(dfa_file, 'w') as f:
            for e in dfa.graph.edges:
                if e[0] == dfa.start:
                    # The artificial start state has no device or port.
                    continue
                # One port per parallel link between the two devices, taken
                # from the source device's side of each link.
                ports = [x['portmap'][e[0].device] for x in self.topo[e[0].device][e[1].device].values()]
                f.write('%s %s %s\n' % (e[0], e[1], ','.join(ports)))

                

    def merge(self, pdfa: DFA, cdfa: DFA):
        """Fold *cdfa* into *pdfa* in place and return *pdfa*.

        Every state reachable from ``cdfa.start`` is handed to ``traverse``,
        beginning at ``pdfa.start`` with the "parent merged" flag set.
        """
        root = pdfa.start
        for first in cdfa.transitions[cdfa.start]:
            self.traverse(pdfa, cdfa, root, first, True)
        return pdfa

    def traverse(self, pdfa: DFA, cdfa: DFA, pcurr: State, curr: State, father_merged: bool):
        """Copy the part of *cdfa* rooted at *curr* below *pcurr* in *pdfa*.

        Args:
            pdfa: DFA being extended in place.
            cdfa: DFA being read.
            pcurr: State of *pdfa* under which *curr* is placed.
            curr: State of *cdfa* currently being copied.
            father_merged: True while every state on the way down was matched
                to an existing *pdfa* state.

        While ``father_merged`` holds, a successor of *pcurr* with the same
        device as *curr* is reused. Once no such successor exists, a new
        state is created and everything below it is created fresh too.
        """
        matched = None
        if father_merged:
            # Snapshot the successors before searching for a same-device state.
            for candidate in list(pdfa.transitions[pcurr]):
                if curr.device == candidate.device:
                    matched = candidate
                    break

        if matched is None:
            fresh = pdfa.add_state(curr.device)
            pdfa.add_transition(pcurr, fresh, fresh.device)
            pcurr = fresh
            # Below a freshly created state nothing can be shared any more.
            father_merged = False
        else:
            pcurr = matched

        for child in cdfa.transitions[curr]:
            self.traverse(pdfa, cdfa, pcurr, child, father_merged)

    def minimize(self, dfa: DFA):
        """Merge one pair of equivalent states of *dfa*, if such a pair exists.

        Two distinct states count as equivalent when they have the same
        ``device`` and ``dfa.serilize`` returns equal values for them. All
        transitions into the second state are redirected to the first, then
        ``dfa.remove_unconncted()`` is called.

        Returns:
            True when a pair was merged (the caller should call again),
            False when no equivalent pair was found.
        """
        pair = self._find_equivalent_pair(dfa)
        if pair is not None:
            keep, drop = pair
            # Materialize the predecessors first: the loop edits the graph.
            for pre in list(dfa.graph.predecessors(drop)):
                dfa.remove_transition(pre, drop)
                dfa.add_transition(pre, keep, keep.device)
        # Previously the rewiring above also ran when nothing matched, acting
        # on whichever nodes were visited last (or on None for a DFA with only
        # a start state, which raised). The cleanup call is kept on both paths.
        dfa.remove_unconncted()
        return pair is not None

    def _find_equivalent_pair(self, dfa: DFA):
        """Return the first ``(keep, drop)`` pair of equivalent states, or None."""
        # The graph is not modified during the search, so one BFS is enough;
        # element 0 of every bfs_predecessors entry is the visited node.
        nodes = [entry[0] for entry in nx.bfs_predecessors(dfa.graph, dfa.start)]
        for x in nodes:
            for y in nodes:
                if x == y:
                    continue
                if x.device == y.device and dfa.serilize(x) == dfa.serilize(y):
                    return x, y
        return None



def gen_req(env, k, rsw_per_pod=48):
    """Write all-pairs ECMP requirements to ``<env>/gen/requirements.txt``.

    Args:
        env: Environment directory. Requirements are only produced when the
            string contains ``'fb'``; otherwise the file is created empty.
        k: Number of pods.
        rsw_per_pod: Number of rack switches per pod (defaults to 48).

    One line is written per ordered pair of distinct rack switches::

        <ip> 16 <src> <src> <dst> ecmp dfa<id>

    where ``<ip>`` is ``(dst_pod << 24) + (dst_rsw << 16)`` and ``<id>``
    counts up from 0 in write order.
    """
    req_file = env + '/gen/requirements.txt'
    os.makedirs(os.path.dirname(req_file), exist_ok=True)
    with open(req_file, 'w') as f:
        if 'fb' not in env:
            # Only the fabric layout is supported; leave the file empty.
            return
        dfa_id = 0
        for src_pod in range(k):
            for src_rsw in range(rsw_per_pod):
                src = 'rsw-%d-%d' % (src_pod, src_rsw)
                for dst_pod in range(k):
                    for dst_rsw in range(rsw_per_pod):
                        if src_pod == dst_pod and src_rsw == dst_rsw:
                            continue
                        dst = 'rsw-%d-%d' % (dst_pod, dst_rsw)
                        ip = (dst_pod << 24) + (dst_rsw << 16)
                        # Named ``line`` so the builtin ``str`` is not shadowed.
                        line = '%d 16 %s %s %s ecmp dfa%d\n' % (ip, src, src, dst, dfa_id)
                        dfa_id += 1
                        f.write(line)

def gen_req_fb(env, k, rsw_per_pod=48):
    """Write per-destination requirements to ``<env>/gen/requirements.txt``.

    Args:
        env: Environment directory. Requirements are only produced when the
            string contains ``'fb'``; otherwise the file is created empty.
        k: Number of pods.
        rsw_per_pod: Number of rack switches per pod (defaults to 48).

    One line is written per rack switch::

        <ip> 16 rsw-<pod>-<idx>.txt

    where ``<ip>`` is ``(pod << 24) + (idx << 16)``.
    """
    req_file = env + '/gen/requirements.txt'
    os.makedirs(os.path.dirname(req_file), exist_ok=True)
    with open(req_file, 'w') as f:
        if 'fb' not in env:
            # Only the fabric layout is supported; leave the file empty.
            return
        for pod in range(k):
            for rsw in range(rsw_per_pod):
                dst = 'rsw-%d-%d' % (pod, rsw)
                ip = (pod << 24) + (rsw << 16)
                # Named ``line`` so the builtin ``str`` is not shadowed.
                line = '%d 16 %s.txt\n' % (ip, dst)
                f.write(line)

if __name__ == '__main__':
    # Command-line entry point: each flag selects one generator; several
    # flags may be combined and are handled in the order below.
    cli = argparse.ArgumentParser()
    cli.add_argument("env", help="the env dir")
    cli.add_argument("-fbo", action="store_true", help="gen fb opt dfa")
    cli.add_argument("-r", action="store_true", help="gen requirements")
    cli.add_argument("-d", action="store_true", help="gen dfa for each requirement")
    cli.add_argument("-fb", action="store_true", help="gen opt dfas for fabric")
    cli.add_argument("-ft", action="store_true", help="gen opt dfas for fattree")
    cli.add_argument("-k", type=int, default=2, help="npods of fabric")
    cli.add_argument("-simple", action="store_true", help="gen simple reachability dfa")
    cli.add_argument("-cutoff", type=int, help="simple path length")
    cli.add_argument("-nopt", action="store_true", help="gen nopt")
    opts = cli.parse_args()

    if opts.simple:
        # Reachability mode constructs the generator without -k;
        # -nopt switches off the final merge into a single DFA.
        DFAGenerator(opts.env).gen_dfa_reach(opts.cutoff, not opts.nopt)
    if opts.ft:
        DFAGenerator(opts.env, opts.k).gen_dfa_opt_ft()
    if opts.fbo:
        DFAGenerator(opts.env, opts.k).gen_dfa_opt()
    if opts.r:
        gen_req_fb(opts.env, opts.k)
    if opts.d:
        # Currently this only loads the topology; no DFA is generated.
        DFAGenerator(opts.env, opts.k)
    if opts.fb:
        DFAGenerator(opts.env, opts.k).gen_dfa_dst()
    